package com.rtsapp.naval.reloadagent.agent;

import com.rtsapp.naval.reloadagent.filemonitor.FileEvent;
import com.rtsapp.naval.reloadagent.filemonitor.FileModifiedListener;
import com.rtsapp.naval.reloadagent.filemonitor.FileMonitor;
import com.rtsapp.naval.reloadagent.filemonitor.JarEvent;
import com.rtsapp.naval.reloadagent.filemonitor.JarModifiedListener;
import com.rtsapp.naval.reloadagent.filemonitor.JarMonitor;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.ClassDefinition;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.logging.Level;
import java.util.logging.Logger;


/**
 * Watches class folders (and optionally a jar folder) and redefines classes through
 * {@link Instrumentation#redefineClasses} when their bytecode changes.
 *
 * <p>Both monitors run on a two-thread scheduled pool, so the listener callbacks may be
 * invoked concurrently; access to the loaded-class cache is therefore synchronized.
 */
public class AgentMonitor implements FileModifiedListener, JarModifiedListener {

    private static final Logger log = Logger.getLogger(AgentMonitor.class.getName());

    /** Lower bound, in milliseconds, for the delay between two consecutive scans. */
    private static final int MONITOR_PERIOD_MIN_VALUE = 1000;

    private final Instrumentation inst;
    private final List<String> classFolders;
    private final String jarFolder;
    private final ScheduledExecutorService service;

    /**
     * Loaded classes keyed by {@link Class#getName()}; built lazily and rebuilt on a cache
     * miss. When several loaders define the same name, the last one reported wins.
     * Guarded by {@code this}.
     */
    private HashMap<String, Class<?>> loadedClassesMap;

    /**
     * Creates the monitor and immediately starts scanning.
     *
     * @param inst instrumentation handle used to enumerate and redefine classes
     * @param args agent configuration: class folders, optional jar folder, scan period
     *             (clamped to at least {@value #MONITOR_PERIOD_MIN_VALUE} ms) and log level
     */
    public AgentMonitor(Instrumentation inst, AgentArgs args) {
        this.inst = inst;
        this.classFolders = args.getClassFolders();
        this.jarFolder = args.getJarFolder();
        int monitorPeriod = Math.max(MONITOR_PERIOD_MIN_VALUE, args.getPeriod());

        log.setLevel(args.getLogLevel());

        // Daemon threads: the agent must never be the reason the host JVM cannot exit.
        service = Executors.newScheduledThreadPool(2, new ThreadFactory() {
            public Thread newThread(Runnable r) {
                Thread thread = new Thread(r, "AgentMonitor");
                thread.setDaemon(true);
                return thread;
            }
        });

        FileMonitor fileMonitor = new FileMonitor(classFolders, "class");
        fileMonitor.addModifiedListener(this);
        service.scheduleWithFixedDelay(fileMonitor, 0, monitorPeriod, TimeUnit.MILLISECONDS);

        if (jarFolder != null) {
            JarMonitor jarMonitor = new JarMonitor(jarFolder);
            jarMonitor.addJarModifiedListener(this);
            service.scheduleWithFixedDelay(jarMonitor, 0, monitorPeriod, TimeUnit.MILLISECONDS);
        }

        log.info("AgentMonitor: watching class folders: " + classFolders);
        log.info("AgentMonitor: watching jars folder: " + jarFolder);
        log.info("AgentMonitor: period between checks (ms): " + monitorPeriod);
        log.info("AgentMonitor: log level: " + log.getLevel());
    }

    /** Shuts down the scheduling executor so no further scans are started. */
    public void stop() {
        service.shutdown();
    }

    /**
     * Redefines the class whose {@code .class} file changed under a watched folder.
     * Failures are logged, never thrown.
     */
    public void fileModified(FileEvent event) {
        File classFile = event.getFile();
        String className = toClassName(event.getFolder().toString(), classFile.toString());
        try {
            // toByteArray closes the stream even when reading fails.
            byte[] classBytes = toByteArray(new FileInputStream(classFile));
            redefineClass(className, classBytes);
        } catch (Exception e) {
            log.log(Level.SEVERE, "fileModified", e);
        } catch (LinkageError e) {
            // redefineClasses reports bad bytecode (e.g. a half-written file) as
            // ClassFormatError/VerifyError; keep it from escaping into the scheduled monitor,
            // whose periodic execution stops once a run throws.
            log.log(Level.SEVERE, "fileModified", e);
        }
    }

    /**
     * Redefines the class stored under the event's entry name in the modified jar.
     * Failures are logged, never thrown.
     */
    public void jarModified(JarEvent event) {
        String entryName = event.getClassName();
        String className = toClassName(null, entryName);
        log.fine("jarModified" + className);
        JarFile jar = event.getJarFile();
        try {
            // toByteArray closes the entry stream even when reading fails.
            byte[] classBytes = toByteArray(jar.getInputStream(getJarEntry(jar, entryName)));
            redefineClass(className, classBytes);
        } catch (Exception e) {
            log.log(Level.SEVERE, "jarModified", e);
        } catch (LinkageError e) {
            // See fileModified: bad bytecode surfaces as an Error, not an Exception.
            log.log(Level.SEVERE, "jarModified", e);
        }
    }

    /**
     * Asks the JVM to replace the bytecode of an already-known class.
     *
     * <p>The class is looked up in the loaded-class cache; on a miss the cache is rebuilt
     * once (the class may have been loaded after the snapshot), and only then is the class
     * loaded through this agent's own class loader.
     *
     * @param className  binary class name, e.g. {@code com.example.Foo$Inner}
     * @param classBytes the new class file contents
     * @throws ClassNotFoundException     if the class is neither loaded nor loadable
     * @throws UnmodifiableClassException if the JVM refuses to modify the class
     */
    protected synchronized void redefineClass(String className, byte[] classBytes)
            throws ClassNotFoundException, UnmodifiableClassException {
        Class<?> clazz = getLoadedClassesMap().get(className);
        if (clazz == null) {
            refreshLoadedClassesMap();
            clazz = loadedClassesMap.get(className);
        }

        if (clazz == null) {
            log.fine("redefineClass clazz not find " + className);
            clazz = this.getClass().getClassLoader().loadClass(className);
            loadedClassesMap.put(className, clazz);
        } else {
            log.fine("redefineClass clazz" + clazz);
        }

        inst.redefineClasses(new ClassDefinition[] { new ClassDefinition(clazz, classBytes) });
        log.info("Redefined " + className);
    }

    /** Returns the loaded-class cache, building it on first use. */
    private synchronized HashMap<String, Class<?>> getLoadedClassesMap() {
        if (loadedClassesMap == null) {
            refreshLoadedClassesMap();
        }
        return loadedClassesMap;
    }

    /** Rebuilds the loaded-class cache from {@link Instrumentation#getAllLoadedClasses()}. */
    private synchronized void refreshLoadedClassesMap() {
        Class<?>[] loadedClasses = inst.getAllLoadedClasses();
        HashMap<String, Class<?>> map = new HashMap<String, Class<?>>(loadedClasses.length * 2);
        for (Class<?> clazz : loadedClasses) {
            map.put(clazz.getName(), clazz);
        }
        loadedClassesMap = map;
    }

    /**
     * Converts a class file path (or jar entry name) to a binary class name.
     *
     * @param baseFolder folder prefix to strip from {@code fullPath}, or {@code null} when
     *                   {@code fullPath} is already relative (e.g. a jar entry name)
     * @param fullPath   path such as {@code <base>/com/example/Foo.class}
     * @return name such as {@code com.example.Foo}
     */
    private static String toClassName(String baseFolder, String fullPath) {
        String className = fullPath;
        if (baseFolder != null && className.startsWith(baseFolder)) {
            className = className.substring(baseFolder.length());
        }

        // Skip the separator(s) between folder and package path, whether or not the folder
        // string already ended with one.
        int start = 0;
        while (start < className.length()) {
            char c = className.charAt(start);
            if (c != '/' && c != File.separatorChar) {
                break;
            }
            start++;
        }
        className = className.substring(start);

        final String suffix = ".class";
        if (className.endsWith(suffix)) {
            className = className.substring(0, className.length() - suffix.length());
        }
        // Jar entry names always use '/', independent of the platform's File.separatorChar.
        return className.replace('/', '.').replace(File.separatorChar, '.');
    }

    /**
     * Returns the entry with the given name from a jar.
     *
     * @throws IllegalArgumentException if the jar has no such entry
     */
    private static JarEntry getJarEntry(JarFile jar, String entryName) {
        JarEntry entry = entryName == null ? null : jar.getJarEntry(entryName);
        if (entry == null) {
            throw new IllegalArgumentException("EntryName " + entryName + " does not exist in jar " + jar);
        }
        return entry;
    }

    /**
     * Reads a stream (typically a {@code .class} file) fully into a byte array.
     * The stream is always closed, whether or not reading succeeds.
     */
    private static byte[] toByteArray(InputStream is) throws IOException {
        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            byte[] buffer = new byte[4096];
            int bytesRead;
            while ((bytesRead = is.read(buffer)) != -1) {
                baos.write(buffer, 0, bytesRead);
            }
            return baos.toByteArray();
        } finally {
            is.close();
        }
    }
}
